%load_ext autoreload
%autoreload
import pandas as pd
import numpy as np
import os
import sys
import re
import matplotlib.pyplot as plt
import datetime
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.cluster import KMeans
from sklearn.cluster import AffinityPropagation
from sklearn.cluster import SpectralClustering
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances_argmin_min
from sklearn import metrics
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D
plt.rcParams['figure.figsize'] = (16, 9)
#plt.style.use('ggplot')
from itertools import cycle
import networkx as nx
from matplotlib_venn import venn2, venn3, venn3_unweighted
import venn
from lineage import get_lineage_coll, get_lineage_snp
from resistance import get_resistance_snp
from coinfection_functions import import_VCF42_to_pandas, filter_repeats, add_snp_distance, \
scatter_vcf_pos, distplot_sns, add_window_distance
pd.set_option('display.max_columns', None)
# Collect every file in `directory` whose name ends in ".snp.hf.pass.vcf".
# sample_list[i] is the part of the file name before the first "." and
# vcf_files[i] is the path of that same file.
directory = "somoza"
# Sort (sample name, path) pairs together. Sorting the two lists separately
# can misalign them, because "." sorts after characters such as "-":
# "A-1.snp..." < "A.snp..." as paths, but "A" < "A-1" as sample names.
vcf_pairs = sorted(
    (file.split(".")[0], os.path.join(directory, file))
    for file in os.listdir(directory)
    if file.endswith(".snp.hf.pass.vcf")
)
sample_list = [name for name, _ in vcf_pairs]
vcf_files = [file_path for _, file_path in vcf_pairs]
print(sample_list)
print(vcf_files)
# Load each VCF into a DataFrame keyed by its sample name and report how long
# the whole import took.
x = datetime.datetime.now()
print(x)
dict_vcf = {
    df_name: import_VCF42_to_pandas(vcf)
    for vcf, df_name in zip(vcf_files, sample_list)
}
y = datetime.datetime.now()
print(y)
elapsed = y - x
print("Done with import in: %s" % elapsed)
# Quick look at the third sample (needs at least three samples on disk).
third_sample_vcf = dict_vcf[sample_list[2]]
third_sample_vcf.head()
third_sample_vcf.columns
# Annotate every loaded sample. The helpers' return values are discarded, so
# they are presumably expected to add their columns to the frame in place
# (TODO confirm against coinfection_functions).
for vcf_df in dict_vcf.values():
    # Row-wise repeat flag (phage, transposon or PE/PPE regions).
    vcf_df['Is_repeat'] = vcf_df.apply(filter_repeats, axis=1)
    # Distance information for nearby positions.
    add_snp_distance(vcf_df)
    # Clustered positions within a sliding window of size 10.
    add_window_distance(vcf_df, window_size=10)
# Column listing of the third sample after annotation.
dict_vcf[sample_list[2]].columns
# Build the filtered frames: dict_dff[sample] keeps only the rows of
# dict_vcf[sample] whose POS is not among the rejected positions.
dict_dff = {}
distance = 2
QD = 15
for sample_name, sample_vcf in dict_vcf.items():
    # A row is rejected when any one of these conditions holds.
    rejected = (
        sample_vcf.snp_left_distance.le(distance)
        | sample_vcf.snp_right_distance.le(distance)
        | sample_vcf.Window_10.ge(3)
        | sample_vcf.Is_repeat.eq(True)
        | sample_vcf.AF.le(0.0)
        | sample_vcf.len_AD.gt(2)
        | sample_vcf.QD.le(QD)
    )
    rejected_positions = sample_vcf.loc[rejected, 'POS'].tolist()
    # Filtering by position (not by row) also drops any other row that shares
    # a rejected POS.
    keep_rows = ~sample_vcf.POS.isin(rejected_positions)
    dict_dff[sample_name] = sample_vcf[keep_rows]
# Spot checks on the third sample: repeat rows were rejected above, and rows
# within 30 of their left neighbour are listed for inspection.
third_sample_dff = dict_dff[sample_list[2]]
third_sample_dff[third_sample_dff.Is_repeat == True].head()
third_sample_dff[third_sample_dff.snp_left_distance <= 30].head()
# Plot every sample three times, in this order: position scatter of the raw
# frames, position scatter of the filtered frames, then the distribution plot
# of the filtered frames.
plot_passes = (
    (dict_vcf, scatter_vcf_pos),
    (dict_dff, scatter_vcf_pos),
    (dict_dff, distplot_sns),
)
for frames_by_sample, plot_sample in plot_passes:
    for name_dsf, df in frames_by_sample.items():
        plot_sample(df, name_dsf)
def split_df_mean(vcf_df, homoz=0.95):
    """Split variant positions around the mean AF of the selected calls.

    The mean is taken over the ``AF`` values of rows with ``AC != 2`` and
    ``gt0 != 1`` (presumably the non-homozygous calls -- confirm against the
    VCF importer).

    Args:
        vcf_df: DataFrame with ``POS``, ``AF``, ``AC`` and ``gt0`` columns.
        homoz: Unused; kept so existing callers that pass it keep working.

    Returns:
        A ``(top_positions, bottom_positions)`` tuple of lists of ``POS``
        values. ``top_positions`` holds the selected rows with
        ``AF >= mean``. ``bottom_positions`` holds every row with
        ``AF < mean``; the ``AC``/``gt0`` selection is not applied to it.
    """
    # Computed once and reused for both the mean and the top split.
    selected = (vcf_df.AC != 2) & (vcf_df.gt0 != 1)
    mean = vcf_df.loc[selected, 'AF'].mean()
    at_or_above_mean = vcf_df['AF'] >= mean
    top_positions = vcf_df.loc[at_or_above_mean & selected, 'POS'].tolist()
    bottom_positions = vcf_df.loc[vcf_df['AF'] < mean, 'POS'].tolist()
    return top_positions, bottom_positions
# Map each sample name to the (top_positions, bottom_positions) tuple that
# split_df_mean returns for its filtered frame.
dict_t_b = {
    name_dff: split_df_mean(df_dff) for name_dff, df_dff in dict_dff.items()
}
# Report how many positions landed in each half, per sample.
for k, v in dict_t_b.items():
    print(k, len(v[0]), len(v[1]))
# Compare the filtered positions of the two SOMOZACOL samples.
S2 = set(dict_dff['SOMOZACOL2'].POS.values)
S3 = set(dict_dff['SOMOZACOL3'].POS.values)
# Positions seen in only one of the two samples, and positions seen in both.
S2_exclusive = S2.difference(S3)
S3_exclusive = S3.difference(S2)
Shared_2_3 = S2.intersection(S3)
total_difference_2_3 = len(S2_exclusive) + len(S3_exclusive)
summary_template = (
    "S2 has %s unique positions\n"
    "S3 has %s unique positions\n"
    "Both share %s positions\n"
    "Total difference: %s"
)
summary_values = (
    len(S2_exclusive),
    len(S3_exclusive),
    len(Shared_2_3),
    total_difference_2_3,
)
print(summary_template % summary_values)
def _plot_venn_grid(venn_func):
    """Draw a 2x3 grid of three-set Venn diagrams and show it.

    Each column is one sample and compares its TOP and BTM position lists
    from ``dict_t_b`` with a reference set: ``S2_exclusive`` on the first
    row and ``S3_exclusive`` on the second.

    Args:
        venn_func: Drawing function called as
            ``venn_func([top, bottom, reference], set_labels=...)``,
            e.g. ``venn3`` or ``venn3_unweighted``.

    Returns:
        The object returned by ``venn_func`` for the last panel drawn.
    """
    samples = ('17152627', '17160843', '17171479')
    references = (('S2', S2_exclusive), ('S3', S3_exclusive))
    # One figure per grid; a second plt.figure call would only leave an
    # empty figure behind.
    plt.figure(figsize=(20, 10))
    diagram = None
    for row, (reference_label, reference_positions) in enumerate(references):
        for col, sample in enumerate(samples):
            # Panels are numbered 1-6, row by row.
            plt.subplot(2, 3, row * len(samples) + col + 1)
            plt.title(sample)
            top_positions, bottom_positions = dict_t_b[sample]
            diagram = venn_func(
                [set(top_positions), set(bottom_positions), reference_positions],
                set_labels=('TOP', 'BTM', reference_label),
            )
    plt.suptitle('S2 vs S3', fontsize=16, verticalalignment='bottom')
    plt.show()
    return diagram


# Area-weighted diagrams first, then the unweighted version of the same grid.
v3 = _plot_venn_grid(venn3)
v3 = _plot_venn_grid(venn3_unweighted)
# Summary statistics and per-column histograms for the third filtered sample.
third_sample_filtered = dict_dff[sample_list[2]]
third_sample_filtered.describe()
third_sample_filtered.hist()
plt.show()
def assign_group_somoza(row):
    """Classify a row by which exclusive position set contains its POS.

    Reads the module-level sets ``S2_exclusive`` and ``S3_exclusive``.

    Args:
        row: Object exposing a ``POS`` attribute (e.g. a DataFrame row).

    Returns:
        1 if ``row.POS`` is in ``S2_exclusive``, 2 if it is in
        ``S3_exclusive``, otherwise 3.
    """
    position = row.POS
    if position in S2_exclusive:
        return 1
    if position in S3_exclusive:
        return 2
    return 3
print(dict_dff.keys())


def _select_ktest_rows(df):
    """Return an independent copy of the rows of df with AC != 2 and gt0 != 1."""
    # .copy() makes the result a frame of its own, so adding the 'category'
    # column below does not raise pandas' SettingWithCopyWarning.
    return df[(df.AC != 2) & (df.gt0 != 1)].copy()


Ktest15 = _select_ktest_rows(dict_dff['17152627'])
Ktest16 = _select_ktest_rows(dict_dff['17160843'])
Ktest17 = _select_ktest_rows(dict_dff['17171479'])
Ktest15.head()
# Label each row: 1 = S2-exclusive position, 2 = S3-exclusive, 3 = neither.
Ktest15['category'] = Ktest15.apply(assign_group_somoza, axis=1)
Ktest16['category'] = Ktest16.apply(assign_group_somoza, axis=1)
Ktest17['category'] = Ktest17.apply(assign_group_somoza, axis=1)
# Inspect the rows that belong to neither exclusive set before dropping them.
Ktest15[Ktest15.category == 3]
Ktest15 = Ktest15[Ktest15.category != 3]
Ktest16 = Ktest16[Ktest16.category != 3]
Ktest17 = Ktest17[Ktest17.category != 3]
Ktest15.gt0[Ktest15.category == 1].value_counts()
# Column sets for the pair plots. `height` replaces the deprecated `size`
# keyword of sns.pairplot.
pairplot_all_vars = ["AF", "QD", "QUAL", "SOR", "dp", "DP", "MQ", "MQRankSum",
                     "FS", "GQ", "BaseQRankSum", "ExcessHet", "REF_AD",
                     "ALT_AD", "ReadPosRankSum"]
pairplot_core_vars = ["AF", "QD", "QUAL", "ExcessHet", "REF_AD", "ALT_AD"]
pairplot_vars_without_gq = ["AF", "QD", "QUAL", "SOR", "dp", "DP", "MQ",
                            "MQRankSum", "FS", "BaseQRankSum", "ExcessHet",
                            "REF_AD", "ALT_AD", "ReadPosRankSum"]
sns.pairplot(Ktest15, hue='category', height=5, vars=pairplot_all_vars, kind='scatter')
sns.pairplot(Ktest15, hue='category', height=5, vars=pairplot_core_vars, kind='scatter')
sns.pairplot(Ktest16, hue='category', height=5, vars=pairplot_all_vars, kind='scatter')
# Same sample without the category colouring, with KDE plots on the diagonal.
sns.pairplot(Ktest16, height=5, vars=pairplot_vars_without_gq, diag_kind="kde", kind='scatter')
sns.pairplot(Ktest17, hue='category', height=5, vars=pairplot_all_vars, kind='scatter')